{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "executionInfo": {
     "elapsed": 17,
     "status": "ok",
     "timestamp": 1649957428444,
     "user": {
      "displayName": "Sam Lu",
      "userId": "15789059763790170725"
     },
     "user_tz": -480
    },
    "id": "WGYnB9z5GEne"
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "from collections import namedtuple\n",
    "import itertools\n",
    "from itertools import count\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.distributions.normal import Normal\n",
    "import numpy as np\n",
    "import collections\n",
    "import random\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "executionInfo": {
     "elapsed": 17,
     "status": "ok",
     "timestamp": 1649957428445,
     "user": {
      "displayName": "Sam Lu",
      "userId": "15789059763790170725"
     },
     "user_tz": -480
    },
    "id": "z8M3b0CiGEnj"
   },
   "outputs": [],
   "source": [
    "class PolicyNet(torch.nn.Module):\n",
    "    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):\n",
    "        super(PolicyNet, self).__init__()\n",
    "        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)\n",
    "        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)\n",
    "        self.fc_std = torch.nn.Linear(hidden_dim, action_dim)\n",
    "        self.action_bound = action_bound\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.fc1(x))\n",
    "        mu = self.fc_mu(x)\n",
    "        std = F.softplus(self.fc_std(x))\n",
    "        dist = Normal(mu, std)\n",
    "        normal_sample = dist.rsample()  # rsample()是重参数化采样函数\n",
    "        log_prob = dist.log_prob(normal_sample)\n",
    "        action = torch.tanh(normal_sample)  # 计算tanh_normal分布的对数概率密度\n",
    "        log_prob = log_prob - torch.log(1 - torch.tanh(action).pow(2) + 1e-7)\n",
    "        action = action * self.action_bound\n",
    "        return action, log_prob\n",
    "\n",
    "\n",
    "class QValueNet(torch.nn.Module):\n",
    "    def __init__(self, state_dim, hidden_dim, action_dim):\n",
    "        super(QValueNet, self).__init__()\n",
    "        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)\n",
    "        self.fc2 = torch.nn.Linear(hidden_dim, 1)\n",
    "\n",
    "    def forward(self, x, a):\n",
    "        cat = torch.cat([x, a], dim=1)  # 拼接状态和动作\n",
    "        x = F.relu(self.fc1(cat))\n",
    "        return self.fc2(x)\n",
    "\n",
    "\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\n",
    "    \"cpu\")\n",
    "\n",
    "\n",
    "class SAC:\n",
    "    ''' 处理连续动作的SAC算法 '''\n",
    "    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,\n",
    "                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma):\n",
    "        self.actor = PolicyNet(state_dim, hidden_dim, action_dim,\n",
    "                               action_bound).to(device)  # 策略网络\n",
    "        # 第一个Q网络\n",
    "        self.critic_1 = QValueNet(state_dim, hidden_dim, action_dim).to(device)\n",
    "        # 第二个Q网络\n",
    "        self.critic_2 = QValueNet(state_dim, hidden_dim, action_dim).to(device)\n",
    "        self.target_critic_1 = QValueNet(state_dim, hidden_dim,\n",
    "                                         action_dim).to(device)  # 第一个目标Q网络\n",
    "        self.target_critic_2 = QValueNet(state_dim, hidden_dim,\n",
    "                                         action_dim).to(device)  # 第二个目标Q网络\n",
    "        # 令目标Q网络的初始参数和Q网络一样\n",
    "        self.target_critic_1.load_state_dict(self.critic_1.state_dict())\n",
    "        self.target_critic_2.load_state_dict(self.critic_2.state_dict())\n",
    "        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),\n",
    "                                                lr=actor_lr)\n",
    "        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),\n",
    "                                                   lr=critic_lr)\n",
    "        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),\n",
    "                                                   lr=critic_lr)\n",
    "        # 使用alpha的log值,可以使训练结果比较稳定\n",
    "        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)\n",
    "        self.log_alpha.requires_grad = True  # 可以对alpha求梯度\n",
    "        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],\n",
    "                                                    lr=alpha_lr)\n",
    "        self.target_entropy = target_entropy  # 目标熵的大小\n",
    "        self.gamma = gamma\n",
    "        self.tau = tau\n",
    "\n",
    "    def take_action(self, state):\n",
    "        state = torch.tensor([state], dtype=torch.float).to(device)\n",
    "        action = self.actor(state)[0]\n",
    "        return [action.item()]\n",
    "\n",
    "    def calc_target(self, rewards, next_states, dones):  # 计算目标Q值\n",
    "        next_actions, log_prob = self.actor(next_states)\n",
    "        entropy = -log_prob\n",
    "        q1_value = self.target_critic_1(next_states, next_actions)\n",
    "        q2_value = self.target_critic_2(next_states, next_actions)\n",
    "        next_value = torch.min(q1_value,\n",
    "                               q2_value) + self.log_alpha.exp() * entropy\n",
    "        td_target = rewards + self.gamma * next_value * (1 - dones)\n",
    "        return td_target\n",
    "\n",
    "    def soft_update(self, net, target_net):\n",
    "        for param_target, param in zip(target_net.parameters(),\n",
    "                                       net.parameters()):\n",
    "            param_target.data.copy_(param_target.data * (1.0 - self.tau) +\n",
    "                                    param.data * self.tau)\n",
    "\n",
    "    def update(self, transition_dict):\n",
    "        states = torch.tensor(transition_dict['states'],\n",
    "                              dtype=torch.float).to(device)\n",
    "        actions = torch.tensor(transition_dict['actions'],\n",
    "                               dtype=torch.float).view(-1, 1).to(device)\n",
    "        rewards = torch.tensor(transition_dict['rewards'],\n",
    "                               dtype=torch.float).view(-1, 1).to(device)\n",
    "        next_states = torch.tensor(transition_dict['next_states'],\n",
    "                                   dtype=torch.float).to(device)\n",
    "        dones = torch.tensor(transition_dict['dones'],\n",
    "                             dtype=torch.float).view(-1, 1).to(device)\n",
    "        rewards = (rewards + 8.0) / 8.0  # 对倒立摆环境的奖励进行重塑\n",
    "\n",
    "        # 更新两个Q网络\n",
    "        td_target = self.calc_target(rewards, next_states, dones)\n",
    "        critic_1_loss = torch.mean(\n",
    "            F.mse_loss(self.critic_1(states, actions), td_target.detach()))\n",
    "        critic_2_loss = torch.mean(\n",
    "            F.mse_loss(self.critic_2(states, actions), td_target.detach()))\n",
    "        self.critic_1_optimizer.zero_grad()\n",
    "        critic_1_loss.backward()\n",
    "        self.critic_1_optimizer.step()\n",
    "        self.critic_2_optimizer.zero_grad()\n",
    "        critic_2_loss.backward()\n",
    "        self.critic_2_optimizer.step()\n",
    "\n",
    "        # 更新策略网络\n",
    "        new_actions, log_prob = self.actor(states)\n",
    "        entropy = -log_prob\n",
    "        q1_value = self.critic_1(states, new_actions)\n",
    "        q2_value = self.critic_2(states, new_actions)\n",
    "        actor_loss = torch.mean(-self.log_alpha.exp() * entropy -\n",
    "                                torch.min(q1_value, q2_value))\n",
    "        self.actor_optimizer.zero_grad()\n",
    "        actor_loss.backward()\n",
    "        self.actor_optimizer.step()\n",
    "\n",
    "        # 更新alpha值\n",
    "        alpha_loss = torch.mean(\n",
    "            (entropy - target_entropy).detach() * self.log_alpha.exp())\n",
    "        self.log_alpha_optimizer.zero_grad()\n",
    "        alpha_loss.backward()\n",
    "        self.log_alpha_optimizer.step()\n",
    "\n",
    "        self.soft_update(self.critic_1, self.target_critic_1)\n",
    "        self.soft_update(self.critic_2, self.target_critic_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "executionInfo": {
     "elapsed": 17,
     "status": "ok",
     "timestamp": 1649957428446,
     "user": {
      "displayName": "Sam Lu",
      "userId": "15789059763790170725"
     },
     "user_tz": -480
    },
    "id": "xfK4N1doGEnl"
   },
   "outputs": [],
   "source": [
    "class Swish(nn.Module):\n",
    "    ''' Swish激活函数 '''\n",
    "    def __init__(self):\n",
    "        super(Swish, self).__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x * torch.sigmoid(x)\n",
    "\n",
    "\n",
    "def init_weights(m):\n",
    "    ''' 初始化模型权重 '''\n",
    "    def truncated_normal_init(t, mean=0.0, std=0.01):\n",
    "        torch.nn.init.normal_(t, mean=mean, std=std)\n",
    "        while True:\n",
    "            cond = (t < mean - 2 * std) | (t > mean + 2 * std)\n",
    "            if not torch.sum(cond):\n",
    "                break\n",
    "            t = torch.where(\n",
    "                cond,\n",
    "                torch.nn.init.normal_(torch.ones(t.shape, device=device),\n",
    "                                      mean=mean,\n",
    "                                      std=std), t)\n",
    "        return t\n",
    "\n",
    "    if type(m) == nn.Linear or isinstance(m, FCLayer):\n",
    "        truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(m._input_dim)))\n",
    "        m.bias.data.fill_(0.0)\n",
    "\n",
    "\n",
    "class FCLayer(nn.Module):\n",
    "    ''' 集成之后的全连接层 '''\n",
    "    def __init__(self, input_dim, output_dim, ensemble_size, activation):\n",
    "        super(FCLayer, self).__init__()\n",
    "        self._input_dim, self._output_dim = input_dim, output_dim\n",
    "        self.weight = nn.Parameter(\n",
    "            torch.Tensor(ensemble_size, input_dim, output_dim).to(device))\n",
    "        self._activation = activation\n",
    "        self.bias = nn.Parameter(\n",
    "            torch.Tensor(ensemble_size, output_dim).to(device))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self._activation(\n",
    "            torch.add(torch.bmm(x, self.weight), self.bias[:, None, :]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "executionInfo": {
     "elapsed": 779,
     "status": "ok",
     "timestamp": 1649957441286,
     "user": {
      "displayName": "Sam Lu",
      "userId": "15789059763790170725"
     },
     "user_tz": -480
    },
    "id": "o8OfdjXJGEnm"
   },
   "outputs": [],
   "source": [
    "class EnsembleModel(nn.Module):\n",
    "    ''' 环境模型集成 '''\n",
    "    def __init__(self,\n",
    "                 state_dim,\n",
    "                 action_dim,\n",
    "                 model_alpha,\n",
    "                 ensemble_size=5,\n",
    "                 learning_rate=1e-3):\n",
    "        super(EnsembleModel, self).__init__()\n",
    "        # 输出包括均值和方差,因此是状态与奖励维度之和的两倍\n",
    "        self._output_dim = (state_dim + 1) * 2\n",
    "        self._model_alpha = model_alpha  # 模型损失函数中加权时的权重\n",
    "        self._max_logvar = nn.Parameter((torch.ones(\n",
    "            (1, self._output_dim // 2)).float() / 2).to(device),\n",
    "                                        requires_grad=False)\n",
    "        self._min_logvar = nn.Parameter((-torch.ones(\n",
    "            (1, self._output_dim // 2)).float() * 10).to(device),\n",
    "                                        requires_grad=False)\n",
    "\n",
    "        self.layer1 = FCLayer(state_dim + action_dim, 200, ensemble_size,\n",
    "                              Swish())\n",
    "        self.layer2 = FCLayer(200, 200, ensemble_size, Swish())\n",
    "        self.layer3 = FCLayer(200, 200, ensemble_size, Swish())\n",
    "        self.layer4 = FCLayer(200, 200, ensemble_size, Swish())\n",
    "        self.layer5 = FCLayer(200, self._output_dim, ensemble_size,\n",
    "                              nn.Identity())\n",
    "        self.apply(init_weights)  # 初始化环境模型中的参数\n",
    "        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)\n",
    "\n",
    "    def forward(self, x, return_log_var=False):\n",
    "        ret = self.layer5(self.layer4(self.layer3(self.layer2(\n",
    "            self.layer1(x)))))\n",
    "        mean = ret[:, :, :self._output_dim // 2]\n",
    "        # 在PETS算法中,将方差控制在最小值和最大值之间\n",
    "        logvar = self._max_logvar - F.softplus(\n",
    "            self._max_logvar - ret[:, :, self._output_dim // 2:])\n",
    "        logvar = self._min_logvar + F.softplus(logvar - self._min_logvar)\n",
    "        return mean, logvar if return_log_var else torch.exp(logvar)\n",
    "\n",
    "    def loss(self, mean, logvar, labels, use_var_loss=True):\n",
    "        inverse_var = torch.exp(-logvar)\n",
    "        if use_var_loss:\n",
    "            mse_loss = torch.mean(torch.mean(torch.pow(mean - labels, 2) *\n",
    "                                             inverse_var,\n",
    "                                             dim=-1),\n",
    "                                  dim=-1)\n",
    "            var_loss = torch.mean(torch.mean(logvar, dim=-1), dim=-1)\n",
    "            total_loss = torch.sum(mse_loss) + torch.sum(var_loss)\n",
    "        else:\n",
    "            mse_loss = torch.mean(torch.pow(mean - labels, 2), dim=(1, 2))\n",
    "            total_loss = torch.sum(mse_loss)\n",
    "        return total_loss, mse_loss\n",
    "\n",
    "    def train(self, loss):\n",
    "        self.optimizer.zero_grad()\n",
    "        loss += self._model_alpha * torch.sum(\n",
    "            self._max_logvar) - self._model_alpha * torch.sum(self._min_logvar)\n",
    "        loss.backward()\n",
    "        self.optimizer.step()\n",
    "\n",
    "\n",
    "class EnsembleDynamicsModel:\n",
    "    ''' 环境模型集成,加入精细化的训练 '''\n",
    "    def __init__(self, state_dim, action_dim, model_alpha=0.01, num_network=5):\n",
    "        self._num_network = num_network\n",
    "        self._state_dim, self._action_dim = state_dim, action_dim\n",
    "        self.model = EnsembleModel(state_dim,\n",
    "                                   action_dim,\n",
    "                                   model_alpha,\n",
    "                                   ensemble_size=num_network)\n",
    "        self._epoch_since_last_update = 0\n",
    "\n",
    "    def train(self,\n",
    "              inputs,\n",
    "              labels,\n",
    "              batch_size=64,\n",
    "              holdout_ratio=0.1,\n",
    "              max_iter=20):\n",
    "        # 设置训练集与验证集\n",
    "        permutation = np.random.permutation(inputs.shape[0])\n",
    "        inputs, labels = inputs[permutation], labels[permutation]\n",
    "        num_holdout = int(inputs.shape[0] * holdout_ratio)\n",
    "        train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]\n",
    "        holdout_inputs, holdout_labels = inputs[:\n",
    "                                                num_holdout], labels[:\n",
    "                                                                     num_holdout]\n",
    "        holdout_inputs = torch.from_numpy(holdout_inputs).float().to(device)\n",
    "        holdout_labels = torch.from_numpy(holdout_labels).float().to(device)\n",
    "        holdout_inputs = holdout_inputs[None, :, :].repeat(\n",
    "            [self._num_network, 1, 1])\n",
    "        holdout_labels = holdout_labels[None, :, :].repeat(\n",
    "            [self._num_network, 1, 1])\n",
    "\n",
    "        # 保留最好的结果\n",
    "        self._snapshots = {i: (None, 1e10) for i in range(self._num_network)}\n",
    "\n",
    "        for epoch in itertools.count():\n",
    "            # 定义每一个网络的训练数据\n",
    "            train_index = np.vstack([\n",
    "                np.random.permutation(train_inputs.shape[0])\n",
    "                for _ in range(self._num_network)\n",
    "            ])\n",
    "            # 所有真实数据都用来训练\n",
    "            for batch_start_pos in range(0, train_inputs.shape[0], batch_size):\n",
    "                batch_index = train_index[:, batch_start_pos:batch_start_pos +\n",
    "                                          batch_size]\n",
    "                train_input = torch.from_numpy(\n",
    "                    train_inputs[batch_index]).float().to(device)\n",
    "                train_label = torch.from_numpy(\n",
    "                    train_labels[batch_index]).float().to(device)\n",
    "\n",
    "                mean, logvar = self.model(train_input, return_log_var=True)\n",
    "                loss, _ = self.model.loss(mean, logvar, train_label)\n",
    "                self.model.train(loss)\n",
    "\n",
    "            with torch.no_grad():\n",
    "                mean, logvar = self.model(holdout_inputs, return_log_var=True)\n",
    "                _, holdout_losses = self.model.loss(mean,\n",
    "                                                    logvar,\n",
    "                                                    holdout_labels,\n",
    "                                                    use_var_loss=False)\n",
    "                holdout_losses = holdout_losses.cpu()\n",
    "                break_condition = self._save_best(epoch, holdout_losses)\n",
    "                if break_condition or epoch > max_iter:  # 结束训练\n",
    "                    break\n",
    "\n",
    "    def _save_best(self, epoch, losses, threshold=0.1):\n",
    "        updated = False\n",
    "        for i in range(len(losses)):\n",
    "            current = losses[i]\n",
    "            _, best = self._snapshots[i]\n",
    "            improvement = (best - current) / best\n",
    "            if improvement > threshold:\n",
    "                self._snapshots[i] = (epoch, current)\n",
    "                updated = True\n",
    "        self._epoch_since_last_update = 0 if updated else self._epoch_since_last_update + 1\n",
    "        return self._epoch_since_last_update > 5\n",
    "\n",
    "    def predict(self, inputs, batch_size=64):\n",
    "        inputs = np.tile(inputs, (self._num_network, 1, 1))\n",
    "        inputs = torch.tensor(inputs, dtype=torch.float).to(device)\n",
    "        mean, var = self.model(inputs, return_log_var=False)\n",
    "        return mean.detach().cpu().numpy(), var.detach().cpu().numpy()\n",
    "\n",
    "\n",
    "class FakeEnv:\n",
    "    def __init__(self, model):\n",
    "        self.model = model\n",
    "\n",
    "    def step(self, obs, act):\n",
    "        inputs = np.concatenate((obs, act), axis=-1)\n",
    "        ensemble_model_means, ensemble_model_vars = self.model.predict(inputs)\n",
    "        ensemble_model_means[:, :, 1:] += obs\n",
    "        ensemble_model_stds = np.sqrt(ensemble_model_vars)\n",
    "        ensemble_samples = ensemble_model_means + np.random.normal(\n",
    "            size=ensemble_model_means.shape) * ensemble_model_stds\n",
    "\n",
    "        num_models, batch_size, _ = ensemble_model_means.shape\n",
    "        models_to_use = np.random.choice(\n",
    "            [i for i in range(self.model._num_network)], size=batch_size)\n",
    "        batch_inds = np.arange(0, batch_size)\n",
    "        samples = ensemble_samples[models_to_use, batch_inds]\n",
    "        rewards, next_obs = samples[:, :1][0][0], samples[:, 1:][0]\n",
    "        return rewards, next_obs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "executionInfo": {
     "elapsed": 636,
     "status": "ok",
     "timestamp": 1649957452282,
     "user": {
      "displayName": "Sam Lu",
      "userId": "15789059763790170725"
     },
     "user_tz": -480
    },
    "id": "T1X6ABP3GEno"
   },
   "outputs": [],
   "source": [
    "class MBPO:\n",
    "    def __init__(self, env, agent, fake_env, env_pool, model_pool,\n",
    "                 rollout_length, rollout_batch_size, real_ratio, num_episode):\n",
    "\n",
    "        self.env = env\n",
    "        self.agent = agent\n",
    "        self.fake_env = fake_env\n",
    "        self.env_pool = env_pool\n",
    "        self.model_pool = model_pool\n",
    "        self.rollout_length = rollout_length\n",
    "        self.rollout_batch_size = rollout_batch_size\n",
    "        self.real_ratio = real_ratio\n",
    "        self.num_episode = num_episode\n",
    "\n",
    "    def rollout_model(self):\n",
    "        observations, _, _, _, _ = self.env_pool.sample(\n",
    "            self.rollout_batch_size)\n",
    "        for obs in observations:\n",
    "            for i in range(self.rollout_length):\n",
    "                action = self.agent.take_action(obs)\n",
    "                reward, next_obs = self.fake_env.step(obs, action)\n",
    "                self.model_pool.add(obs, action, reward, next_obs, False)\n",
    "                obs = next_obs\n",
    "\n",
    "    def update_agent(self, policy_train_batch_size=64):\n",
    "        env_batch_size = int(policy_train_batch_size * self.real_ratio)\n",
    "        model_batch_size = policy_train_batch_size - env_batch_size\n",
    "        for epoch in range(10):\n",
    "            env_obs, env_action, env_reward, env_next_obs, env_done = self.env_pool.sample(\n",
    "                env_batch_size)\n",
    "            if self.model_pool.size() > 0:\n",
    "                model_obs, model_action, model_reward, model_next_obs, model_done = self.model_pool.sample(\n",
    "                    model_batch_size)\n",
    "                obs = np.concatenate((env_obs, model_obs), axis=0)\n",
    "                action = np.concatenate((env_action, model_action), axis=0)\n",
    "                next_obs = np.concatenate((env_next_obs, model_next_obs),\n",
    "                                          axis=0)\n",
    "                reward = np.concatenate((env_reward, model_reward), axis=0)\n",
    "                done = np.concatenate((env_done, model_done), axis=0)\n",
    "            else:\n",
    "                obs, action, next_obs, reward, done = env_obs, env_action, env_next_obs, env_reward, env_done\n",
    "            transition_dict = {\n",
    "                'states': obs,\n",
    "                'actions': action,\n",
    "                'next_states': next_obs,\n",
    "                'rewards': reward,\n",
    "                'dones': done\n",
    "            }\n",
    "            self.agent.update(transition_dict)\n",
    "\n",
    "    def train_model(self):\n",
    "        obs, action, reward, next_obs, done = self.env_pool.return_all_samples(\n",
    "        )\n",
    "        inputs = np.concatenate((obs, action), axis=-1)\n",
    "        reward = np.array(reward)\n",
    "        labels = np.concatenate(\n",
    "            (np.reshape(reward, (reward.shape[0], -1)), next_obs - obs),\n",
    "            axis=-1)\n",
    "        self.fake_env.model.train(inputs, labels)\n",
    "\n",
    "    def explore(self):\n",
    "        obs, done, episode_return = self.env.reset(), False, 0\n",
    "        while not done:\n",
    "            action = self.agent.take_action(obs)\n",
    "            next_obs, reward, done, _ = self.env.step(action)\n",
    "            self.env_pool.add(obs, action, reward, next_obs, done)\n",
    "            obs = next_obs\n",
    "            episode_return += reward\n",
    "        return episode_return\n",
    "\n",
    "    def train(self):\n",
    "        return_list = []\n",
    "        explore_return = self.explore()  # 随机探索采取数据\n",
    "        print('episode: 1, return: %d' % explore_return)\n",
    "        return_list.append(explore_return)\n",
    "\n",
    "        for i_episode in range(self.num_episode - 1):\n",
    "            obs, done, episode_return = self.env.reset(), False, 0\n",
    "            step = 0\n",
    "            while not done:\n",
    "                if step % 50 == 0:\n",
    "                    self.train_model()\n",
    "                    self.rollout_model()\n",
    "                action = self.agent.take_action(obs)\n",
    "                next_obs, reward, done, _ = self.env.step(action)\n",
    "                self.env_pool.add(obs, action, reward, next_obs, done)\n",
    "                obs = next_obs\n",
    "                episode_return += reward\n",
    "\n",
    "                self.update_agent()\n",
    "                step += 1\n",
    "            return_list.append(episode_return)\n",
    "            print('episode: %d, return: %d' % (i_episode + 2, episode_return))\n",
    "        return return_list\n",
    "\n",
    "\n",
    "class ReplayBuffer:\n",
    "    def __init__(self, capacity):\n",
    "        self.buffer = collections.deque(maxlen=capacity)\n",
    "\n",
    "    def add(self, state, action, reward, next_state, done):\n",
    "        self.buffer.append((state, action, reward, next_state, done))\n",
    "\n",
    "    def size(self):\n",
    "        return len(self.buffer)\n",
    "\n",
    "    def sample(self, batch_size):\n",
    "        if batch_size > len(self.buffer):\n",
    "            return self.return_all_samples()\n",
    "        else:\n",
    "            transitions = random.sample(self.buffer, batch_size)\n",
    "            state, action, reward, next_state, done = zip(*transitions)\n",
    "            return np.array(state), action, reward, np.array(next_state), done\n",
    "\n",
    "    def return_all_samples(self):\n",
    "        all_transitions = list(self.buffer)\n",
    "        state, action, reward, next_state, done = zip(*all_transitions)\n",
    "        return np.array(state), action, reward, np.array(next_state), done"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 680
    },
    "executionInfo": {
     "elapsed": 613836,
     "status": "ok",
     "timestamp": 1649958070782,
     "user": {
      "displayName": "Sam Lu",
      "userId": "15789059763790170725"
     },
     "user_tz": -480
    },
    "id": "_gcY5HvTGEnr",
    "outputId": "49c828a2-35ec-44d9-f952-a52e01e46fe8"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:59: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "episode: 1, return: -1617\n",
      "episode: 2, return: -1463\n",
      "episode: 3, return: -1407\n",
      "episode: 4, return: -929\n",
      "episode: 5, return: -860\n",
      "episode: 6, return: -643\n",
      "episode: 7, return: -128\n",
      "episode: 8, return: -368\n",
      "episode: 9, return: -118\n",
      "episode: 10, return: -123\n",
      "episode: 11, return: -122\n",
      "episode: 12, return: -118\n",
      "episode: 13, return: -119\n",
      "episode: 14, return: -119\n",
      "episode: 15, return: -121\n",
      "episode: 16, return: -123\n",
      "episode: 17, return: 0\n",
      "episode: 18, return: -125\n",
      "episode: 19, return: -126\n",
      "episode: 20, return: -243\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAAEWCAYAAACjYXoKAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXxU1f3/8dc7CUnYd0H2xaCCK0YE69aKCrYW14pWRWtLbbXbt7Xq1y528fe1trWtttXiUtHWqq1V+X7V4lYLakAWEUEEQthlycK+ZJvP7497o2NMQshk5k6Sz/PxmMfcOefM3M9MkvnknnPvOTIznHPOuURkRB2Ac865ls+TiXPOuYR5MnHOOZcwTybOOecS5snEOedcwjyZOOecS5gnE+daGUlnSNrQ3G2da4gnE5cWJK2RVCGpV63ytyWZpCHh44fDdrsl7ZK0QNLpce2vllQd1u+UtEjS5+Lqu0m6V9JmSXslvSvpmhS9xzMkxeJiX56qfbcEki6XtFbSHknPSOoRdUyu8TyZuHSyGris5oGko4EOdbS708w6AV2Ae4F/SsqMqy8I67sBDwJPSuouKRt4GRgMjAO6AjcCd0j6r2S8oTp8EBf7TcD9kkamaN9pS9Io4E/AlUAfYC/wx0iDcgfFk4lLJ48CV8U9ngI8Ul9jC6ZveAzoQfAFVLs+BjwEtAeGE3xRDQIuMbPVZlZpZv8Cvgn8VFKXuvYj6WRJ8yTtCO9Pjqt7TdLPJL0RHm28WPvoqr7YzewZYBswUlKGpJslrZJUKunJmv/MJQ0Jj86mSFonqUTSrXExtA+P2LZJeg84sVb8JumwuMcPS/p5Pe+13rY1XWKSvi9pq6RNks6XdK6kFZLKJP13Pa/bT9K++KMNSceH76Ud8EXgf81slpntBn4IXCip84E+S5cePJm4dDIH6CLpyPBIYzLwl/oah22uIjii2VJHfRbwZWA3sBI4C3jBzPbUavoUkEtwtFL7NXoAzwF3Az2Bu4DnJPWMa3Y5cA1wCJANfO9AbzRMHhcQHD29C3wDOB84HehHkGT+UOtppwCHA2cCP5J0ZFj+Y4JkORw4hyAJJ0tfgs+qP/Aj4H7gCuAE4FTgh5KG1n6SmX0AFAAXxRVfDvzDzCqBUcA7ce1XARXAiOS8DdfcPJm4dFNzdHIWsAzYWEeb70naTpAkfgv80Myq4+rHhvWbCbrNLjCzHUAvYFPtFzOzKqAkrK/ts8BKM3vUzKrM7G/A+8B5cW3+bGYrzGwf8CRwXAPvr18YWwlBErjSzJYD1wG3mtkGMysHbgMuDhNijZ+Y2T4ze4fgi/fYsPwLwO1mVmZm6wkSX7JUhvuqBB4n+Mx+Z2a7zGwp8F5cXLU9RtiNKUkE/yw8FtZ1AnbUar8D8COTFiLrwE2cS6lHgVnAUOrv4vqVmf0g/EIaBbwoqczMXgjr55jZKXU8rwQ4tHZh+IXdK6yvrR+wtlbZWoL/zGtsjtveS/DFWJ8PzGxAHeWDgaclxeLKqvl49119++kHrK8VX7KUxiXufeF9/FHhvpq4JO2OKx9JcAR4j6RDCY44YsDssH43wThSvC7AruYL3SWTH5m4tGJmawm6rc4F/nmAtmZmS4A3CI4gDuRlYKKkjrXKLwLKCbrZavuA4Is+3iDqPmJKxHpgopl1i7vlmllj9rMJGFgrvnh7+fiJDH0beK2DadsgM+sUd1tnZtuAF4FLCbq4HrePpi1fStwRjaRhQA6woqn7d6nlycSlo2uBz9QxtvEJko4gGEtY2ojXfRTYAPw9HNRuJ+kcgm6h28KusNqeB0aEp61mSbqU4L/s/2vsm2mk+4DbJQ0GkNRb0qRGPvdJ4JbwjLUBBOMv8RYBl0vKlDSBYFymPgfTtikeI+jGvJiPurgA/gqcJ+nUMNn/FPinmfmRSQvhycSlHTNbZWbzG2jy/fBajT0E/+n+meC00gO9
bjkwnuAoYC6wk2BA/VYz+2U9zykFPgd8FygFvg98zszq6hJLxO+AGQRddrsIjpJOauRzf0LQtbWa4PN4tFb9twjGeLYTnDX1TAOvdTBtm2IGkAdsDsd+AAjHW64jSCpbCcZKvt7M+3ZJJF8cyznnXKL8yMQ551zCPJk455xLmCcT55xzCfNk4pxzLmFt9qLFXr162ZAhQ6IOwznnWpQFCxaUmFnv2uVtNpkMGTKE+fMbOvvUOedcbZLqnGHBu7mcc84lzJOJc865hHkycc45lzBPJs455xLmycQ551zCWk0ykTRB0nJJhZJujjoe55xrS1pFMgmXb/0DMJFgevDLJI2MNirnnGs7WkUyAcYAhWZWZGYVBMuJNnYtCOeca5TNO/bz2Nx1VFbHDty4jWktFy325+PLlm6gjrUgJE0FpgIMGlR7MTrnnKtfUfFurnhgLh/s2M9ry7dyz+XHk5OVGXVYaaO1HJk0iplNM7N8M8vv3fsTswE451ydln6wg0vuK6C8KsZXTx/Gi+9t4cvT57Ovojrq0NJGa0kmG/n4GtgDaP41up1zbdD8NWVMnjaHnKwMnrxuHLdMPJI7LzqGNwpLmPLQW+zaXxl1iGmhtSSTeUCepKGSsoHJBMuDOudck722fCtXPDiX3p1y+PvXTmZ4704AfOHEgfxu8vEsXLeNKx6Yy/a9FRFHGr1WkUzMrAq4AZgJLAOeDNeUds65Jnlu8Sa+8sh8hvXqxJPXjaN/t/Yfqz/v2H7cd8UJLNu0i8nT5lC8qzyiSNNDq0gmAGb2vJmNMLPhZnZ71PE451quJ+at4xt/W8ixA7rxt6lj6dUpp85240f24aGrT2Rt6V4u/VMBH2zfl+JI00erSSbOOdcc7p9VxE1Pvcupeb159NqT6Nq+XYPtT8nrxaPXjqF4VzmX3FfA2tI9KYo0vXgycc45wMz49YvLuf35ZXz26EO5/6p82mc37tTf/CE9eOwrY9lbUcUl9xWwcsuuJEebfjyZOOfavFjMuG3GUu55tZBL8wdy92XHk511cF+PRw/oyhNfHYcBl06bw5KNO5ITbJryZOKca9Mqq2N89+/vML1gLVNPG8YdFx1NZoaa9Foj+nTm718dR/t2mVx2/xwWrC1r5mjTlycT51ybtb+ymq/9ZSFPv72RG885nFsmHoHUtERSY0ivjjx53Th6dcrhygff4s3CkmaKNr15MnHOtUm7y6v40sPzeHnZFn46aRTXf/qwhBNJjf7d2vPEV8cysHsHrn54Hq++v6VZXjedeTJxzrU52/dW8MUH5jJ3dRm/ufRYrho3pNn3cUjnXB6fOpbD+3Rm6iMLeG7xpmbfRzppLRM9OpcSZkbZngqyszLIbZdJu8z0/3/MzCivilFRHaO8MkZVLEaGFN4I7jOC7cyMoFyCTH203Vz/saeDLTv3c+WDc1lTupf7rjiBs0b2Sdq+unfM5q9fOYlrH57HN/62kL0Vx3BJ/sADP7EF8mTi3EH4+XPLePD11R8+zswQuWFiyam5b5dJbrsMcrMyyQnvc9sFdUECEjEDMzAsuDfDgJgFj2MGULMdtgnrYzH7MDGUV8WoqIpRXlVNeVWtx5UxyquDx4n6MOmEyaWGxTey+M2P1WAff0hmhsjKUHCfmfHxxx/eh+WZnyzPyBCZgsyMDDIzgtfLzMggU5AR1zZDCutEZrj9wpLNlO4u5+FrTuTk4b0S/mwOpEtuO6Z/aQxffXQBN/5jMcW7yxk9qDsiiFUEyTpDcfcEn3PN5/3RPXRp345DOucmPe6D5cnEtSiPzlnLK8u28OCUE5t8xk1TrS/byyMFa/j04b05eXgvyquq2V8ZY39lNfvjtys/+jIv21PxYVlwX01VzD78ApGo88sEPvryiP9igeCLMycrg+ysDHKyMsjJyqRbh+xgu10m2ZkZ5LTL+PA+JyszbBfcMjMywgRlxMJkVR37KHFV1yS0WLAdC5NddeyjbeI+esU90MfKP66mzgyq
zaiuNqpiwesG97GPP66uu3xvRRXVNfHV3MIkW9MuZsF2zXuorg7vY0avTjn89StjOW5gt+b/JalHh+ws7r8qnxsee5s7/7U8odeS4LbzRjHl5CHNE1wz8WTiWpS/z1/P4g07eGrhBr6Q4u6C372yEkn8z4XH0Ldr+v1n6NJbbrtM/nTlCSxav53yqmoIj0CNj5J0/JFoLDxatQ8TelD2zNsb+fGMpWzZuZ8bzzk8bbogPZm4FmPn/soPLwS768UVnHdMv0ZfoZyolVt28c+FG7j2lKGeSFyTZWaIEwZ3T+g1Jh7Vlx8+u4Q/vraKLTvLueOio9Ni7C76CJxrpLeKyogZfHt8Hpt37ufPb64+8JOayV0vraB9u0y+dsZhKdunc3XJyszg/11wNN8ZP4KnFm7gy9Pns6e8KuqwPJm4lqOgqJTsrAyuO3044488hHv/vYqyPclfR2Lxhu28sGQzXz51GD06Zid9f84diCS+NT6POy48mtkri7ns/jmU7I52CnxPJq7FKFhVyuhB3chtl8lNE45gT0UV97y6Mun7/dWLK+jeoR1fPnVo0vfl3MGYPGYQ067MZ8WWXVx075usKYluxuK0SyaSfinpfUmLJT0tqVtc3S2SCiUtl3ROXPmEsKxQ0s3RRO6SafveCpZt3sm4YcGpnHl9OvOF/IH8Zc5a1pXuTdp+5xSVMmtFMV8/4zA65zY8FblzURg/sg+PfWUsO/dVctG9b7J4w/ZI4ki7ZAK8BBxlZscAK4BbACSNJFiOdxQwAfijpExJmcAfgInASOCysK1rReYUlWEG44b3/LDsO2eNIDND/PLFxE61rI+Z8cuZy+nTJYcrxw1Oyj6caw6jB3XnH187mdx2mUyeNofXlm9NeQxpl0zM7MVwGV6AOcCAcHsS8LiZlZvZaqAQGBPeCs2syMwqgMfDtq4VmVNUSm67DI4d2PXDsj5dcvnyKcP433c+SMp/Y/9evpUFa7fxzTPzyG2XmrPGnGuq4b078fTXT2ZIz458efp8/rFgQ0r3n3bJpJYvAS+E2/2B9XF1G8Ky+spdK1KwqpQTh/QgJ+vjX+pfPT0YFP+f598PLqZrJrGY8cuZKxjcs0PKr2dxrqkO6ZLLE18dy0nDevC9v7/DH/5d2Kx/Fw2JJJlIelnSkjpuk+La3ApUAX9txv1OlTRf0vzi4uLmelmXZKW7y1m+ZRdjh/X8RF3n3HZ88zOHUVBUymsrmu9n+ty7m1i2aSf/ddaItDiH37nG6pzbjj9fPYZJx/XjlzOX86Nnl1IdS35CieSiRTMb31C9pKuBzwFn2kdpdSMQ/y/igLCMBspr73caMA0gPz8/NenaJWxOUbDAUPx4SbzLTxrMn99cwx3Pv89peb0TnmalsjrGXS+t4Ii+nTnvmH4JvZZzUcjOyuA3XziOQzrncP/s1RTvKue3k49Landt2v3LJWkC8H3g82YWf5rODGCypBxJQ4E84C1gHpAnaaikbIJB+hmpjtslT0FRCR2zMzm6f9c667OzMrjxnMNZvmUXTy1MvJ/4qQUbWF2yh++efTgZKZ7/y7nmkpEhbv3sSH7w2SP519LNXPXgW+zYW5m8/SXtlZvu90Bn4CVJiyTdB2BmS4EngfeAfwHXm1l1OFh/AzATWAY8GbZ1rUTBqlJOHNqjwe6mzx59KMcO6MpdL65gf2V1k/e1v7Ka372ykuMHdWP8kYc0+XWcSxdfPnUYd192PIvWb+eSP73JB9v3JWU/aZdMzOwwMxtoZseFt+vi6m43s+FmdriZvRBX/ryZjQjrbo8mcpcMW3fuZ1XxHsbVMV4STxK3nHskm3fu56E3mj7Nyl/nrmPTjvSaQM+5RH3+2H48/KUT2bR9Pxf+8U1WFe9u9n2kXTJxLl5BUSlQ/3hJvLHDenLmEU2fZmV3eRV//HchpxzWKyXrXDiXSicP78UTXx1HXp9O9OqY0+yv78nEpbU5RaV0zs1iVL+6x0tqu2liMM3K718t
POh9PfT6akr3VPC9cw4/6Oc61xKM7NeFR689ia4dmn82B08mLq0VrCrlpKE9Gn2G1og+nbnkhIE8OmfNQU2zsm1PBffPKuLskX1SumiSc62FJxOXtjbt2Mea0r11Xl/SkJppVn51ENOs3DdrFbsrqvju2X5U4lxTeDJxaatgVePHS+L17RpMszKjkdOsbNm5n+lvruGC4/pzeN/OTYrVubbOk4lLWwWrSunWoR1H9u1y0M89mGlW7nl1JVXVxrfHj2hqqM61eZ5MXNoqKArGS5py4WBjp1lZV7qXx99az+QxAxnUs0Mi4TrXpnkycWlpfdleNmzbd8DrSxpy+UmDGdyzA7944f165yb67csryMwQ3/hMXpP345zzZOLS1EfXlzT9eo+aaVbe37yLf9YxzcqKLbt4etFGrj55CH265DZ5P845TyYuTc1ZVUrPjtmM6NMpodf5cJqVlz45zcqvZi6nU3YW150+PKF9OOc8mbg0ZGYUFJUydljPhKc0qZlmZdOO/fz5jTUfli9av50X39vCV04bRveO2QlG7JzzZOLSztrSvWzasZ+xB3lKcH1qpln542uFbAunWfnVzOX06JjNl04Z2iz7cK6t82Ti0s6bNdeXJDD4XttNE49gT3kVv/93IW8WlvB6YQlfP2M4nXIiWdLHuVbH/5Jc2ikoKqV35xyG9+7YbK9ZM83KIwVrmL2ymEO75nLF2MHN9vrOtXV+ZOLSiplRsKqUcc0wXlJbzTQrK7bs5ltn5iV11Tnn2hpPJi6trCreTcnu8oOeQqUx+nbN5aYJR3BqXi8uOmFAs7++c21Z2iYTSd+VZJJ6hY8l6W5JhZIWSxod13aKpJXhbUp0UbtEFSRhvCTeNZ8ayqPXntTgqo3OuYOXlmMmkgYCZwPr4oonEqz7ngecBNwLnCSpB/BjIB8wYIGkGWa2LbVRu+ZQUFTKoV1zGexTmzjXoqTrv2e/Ab5PkBxqTAIescAcoJukQ4FzgJfMrCxMIC8BE1IesUtYLGbMKSpLyniJcy650i6ZSJoEbDSzd2pV9QfWxz3eEJbVV17Xa0+VNF/S/OLi+if/c9FYsXUXZXsqmu36Eudc6kTSzSXpZaBvHVW3Av9N0MXV7MxsGjANID8/v+F5yV3KJXu8xDmXPJEkEzMbX1e5pKOBocA7YTfHAGChpDHARmBgXPMBYdlG4Ixa5a81e9Au6QpWlTKge3sG9vDxEudamrTq5jKzd83sEDMbYmZDCLqsRpvZZmAGcFV4VtdYYIeZbQJmAmdL6i6pO8FRzcyo3oNrmljMmLu6zI9KnGuh0vJsrno8D5wLFAJ7gWsAzKxM0s+AeWG7n5pZWTQhuqZ6b9NOduyrTMr1Jc655EvrZBIendRsG3B9Pe0eAh5KUVguCeYUNW29d+dcekirbi7XdhWsKmVIzw4c2rV91KE455rAk4mLXFV1jLdWl/lRiXMtmCcTF7mlH+xkV3kVY33w3bkWy5OJi9yH6717MnGuxfJk4iJXsKqU4b07ckiX3KhDcc41kScTF6nK6hjz1vh4iXMtnScTF6nFG3awt6KaccN6RR2Kcy4BnkxcpGquLxk7rEfEkTjnEuHJxEWqYFUph/fpTM9OOVGH4pxLgCcTF5nyqmrmr/XxEudaA08mLjLvrN/B/sqYX1/iXCvgycRFpmBVKZKPlzjXGngycZEpKCrhyL5d6NYhO+pQnHMJ8mTiIrG/spqF67b7eIlzrYQnExeJheu2UVEV8ylUnGsl0jKZSPqGpPclLZV0Z1z5LZIKJS2XdE5c+YSwrFDSzdFE7Q7GnFWlZAjG+HiJc61C2i2OJenTwCTgWDMrl3RIWD4SmAyMAvoBL0saET7tD8BZBMv8zpM0w8zeS330rrEKiko5qn9XuuS2izoU51wzSMcjk68Bd5hZOYCZbQ3LJwGPm1m5ma0mWL53THgrNLMiM6sAHg/bujS1r6KaReu3exeXc61IOiaTEcCpkuZK+o+kE8Py
/sD6uHYbwrL6yl2amr+2jMpqY6wPvjvXakTSzSXpZaBvHVW3EsTUAxgLnAg8KWlYM+13KjAVYNCgQc3xkq4JClaVkpkhThzi4yXOtRaRJBMzG19fnaSvAf80MwPekhQDegEbgYFxTQeEZTRQXnu/04BpAPn5+dbkN+ASUlBUyjEDutIpJ+2G7JxzTZSO3VzPAJ8GCAfYs4ESYAYwWVKOpKFAHvAWMA/IkzRUUjbBIP2MSCJ3B7S7vIrFG3b4eIlzrUw6/mv4EPCQpCVABTAlPEpZKulJ4D2gCrjezKoBJN0AzAQygYfMbGk0obsDmbemjOqY+cWKzrUyaZdMwjOyrqin7nbg9jrKnweeT3Jorhn8Z3kx7TJF/mAfL3GuNUnHbi7XSq0p2cNjb63js0cfSvvszKjDcc41I08mLiXMjNv+dynZmRnccu6RUYfjnGtmnkxcSsxcuoXXlhfz7fF59OmSG3U4zrlm5snEJd3eiip+9n/vcUTfzlx98pCow3HOJUGjkomkb0nqosCDkhZKOjvZwbnW4fevFrJx+z5+OukosjL9/xfnWqPG/mV/ycx2AmcD3YErgTuSFpVrNQq37ub+2UVcOLo/Y4b6GVzOtVaNTSYK788FHg2v41AD7Z0LBt1nLCW3XSa3TPRBd+das8YmkwWSXiRIJjMldQZiyQvLtQbPvbuJ1wtL+N7Zh9O7c07U4TjnkqixFy1eCxwHFJnZXkk9gWuSF5Zr6XaXB4Puo/p14Yqxg6MOxzmXZI1KJmYWk7QFGCkp7a6ad+nn7ldWsmVnOX/84glkZniPqHOtXaMSg6RfAJcSzItVHRYbMCtJcbkWbMWWXTz0+mouzR/ICYO7Rx2Ocy4FGnuUcT5weM3qh87Vx8z44TNL6JiTxU0Tj4g6HOdcijR2AL4I8MW63QE9u+gD5q4u4/sTDqdHx+yow3HOpUhjj0z2AoskvQJ8eHRiZt9MSlSuRdq5v5Lbn1/GsQO6MvlEX8nSubaksclkBr7glDuA37y0gpLd5Tw4Jd8H3Z1rYw6YTCRlAleb2adTEI9rod77YCfT31zD5WMGccyAblGH45xLsQOOmYSrGcYkdU1BPEg6TtIcSYskzZc0JiyXpLslFUpaLGl03HOmSFoZ3qakIk73kVjM+NGzS+jWIZsbzzk86nCccxFobDfXbuBdSS8Be2oKkzRmcifwEzN7QdK54eMzgIkE677nAScB9wInSeoB/BjIJzhdeYGkGWa2LQmxuTo8tXAD89du486LjqFbBx90d64tamwy+Wd4SwUDuoTbXYEPwu1JwCPhevBzJHWTdChBonnJzMoAwoQ3AfhbiuJt03bsreSOF95n9KBuXHzCgKjDcc5FpLFXwE9PdiBxvk0w/9evCLrhTg7L+wPr49ptCMvqK/8ESVOBqQCDBvnZRs3hVy8uZ9veCh65dgwZPujuXJvV2CvgVxMcMXyMmQ1ryk4lvQz0raPqVuBM4Dtm9pSkLwAPAuObsp/azGwaMA0gPz//E+/HHZx3N+zgL3PXMmXcEEb1S8mQmnMuTTW2mys/bjsXuARo8uIUZlZvcpD0CPCt8OHfgQfC7Y3AwLimA8KyjQRdXfHlrzU1Ntc4sZjxw2eX0LNjDt85a0TU4TjnItaoK+DNrDTuttHMfgt8NkkxfQCcHm5/BlgZbs8ArgrP6hoL7DCzTcBM4GxJ3SV1J1jAa2aSYnOhJ+avZ9H67fz3uUfQtb1PjuBcW9fYbq7RcQ8zCI5UkjV78FeA34WzE+8nHOMAnidYT6WQ4Ir8awDMrEzSz4B5Ybuf1gzGu+TYtqeCX/zrfcYM6cEFx9c5POWca2MamxB+HbddBawGvtD84YCZvQ6cUEe5AdfX85yHgIeSEY/7pDtnvs+u/VX89PxRSD7o7pw7iMWxzKwovkDS0CTE49LcgrVlPD5vPdd+aihH9O1y4Cc459qExiaTfwCj6yj7xBGEa5227NzPPa+u5Il56+nTOZdvjc+L
OiTnXBppMJlIOgIYBXSVdGFcVReCs7pcK1e2p4L7/rOK6W+uoTpmXHriQL55Zh6dc33Q3Tn3kQMdmRwOfA7oBpwXV76LYKDctVK79lfywOzVPPj6avZUVHHB8f359pkjGNSzQ9ShOefSUIPJxMyeBZ6VNM7MClIUk4vQvopqHilYw73/WcX2vZVMPKov/3XWCPL6dI46NOdcGmvsmElpuDBWHzM7StIxwOfN7OdJjM2lUEVVjCfmreOeVwvZuquc00b05ntnj/Dp5J1zjdLYZHI/cCPwJwAzWyzpMcCTSQtXHTOefnsjv315BRu27ePEId2557LjOWlYz6hDc861II1NJh3M7K1a1xRUJSEelyKxmPGvpZu566UVFG7dzVH9u/Dz84/i9BG9/doR59xBa2wyKZE0nHCyR0kXA5uSFpVLGjPjtRXF/PrF5SzZuJPDDunEvV8czYSj+noScc41WWOTyfUEs+0eIWkjwRXwX0xaVC5pfvPSCu5+tZAB3dvz60uO5fzj+/t67c65hDV2PZMiYLykjgRzc+0FJgNrkxiba2ZmxlMLN3JqXi8enHIi2VmNmufTOecOqMFvE0ldJN0i6feSziJIIlMIJltMytxcLnmKSvawcfs+zhnV1xOJc65ZHejI5FFgG1BAcJHirYCAC8xsUZJjc81s9opiAE4f0TviSJxzrc2BkskwMzsaQNIDBIPug8xsf9Ijc81u9soShvTswMAefhW7c655Haivo7Jmw8yqgQ2eSFqmiqoYBUWlnJrnRyXOueZ3oGRyrKSd4W0XcEzNtqSdTd2ppEskLZUUk5Rfq+4WSYWSlks6J658QlhWKOnmuPKhkuaG5U9Iym5qXK3ZwnXb2FtRzal5vaIOxTnXCjWYTMws08y6hLfOZpYVt53IYhZLgAuBWfGFkkYSnCU2CpgA/FFSpqRM4A/ARGAkcFnYFuAXwG/M7DCC8Z1rE4ir1Zq9spjMDDFuuF/Z7pxrfpGc0mNmy8xseR1Vk4DHzazczFYTnDU2JrwVmlmRmVUAjwOTFFxl9xmCtVUApgPnJ/8dtDyzVpQwelA3nzreOZcU6XZ+aH9gfdzjDWFZfeU9ge1mVlWrvE6SpkqaL2l+cXFxswaezkp3l7Pkg41thOIAABLASURBVB0+XuKcS5rGXgF/0CS9DPSto+rWcGr7lDOzaQRX8pOfn29RxBCFN1aVYgan+SnBzrkkSVoyMbPxTXjaRmBg3OMBYRn1lJcC3SRlhUcn8e1daPaKYrq2b8fR/btGHYpzrpVKt26uGcBkSTmShgJ5wFvAPCAvPHMrm2CQfoaZGfBv4OLw+VOASI560pWZMXtlCacc1svn4HLOJU0kyUTSBZI2AOOA5yTNBDCzpcCTwHvAv4Drzaw6POq4AZgJLAOeDNsC3AT8l6RCgjGUB1P7btJb4dbdbN65308Jds4lVdK6uRpiZk8DT9dTdztwex3lzwPP11FeRHC2l6vDf8IpVE7xZOKcS6J06+ZyzWz2yhKG9e7IgO4+hYpzLnk8mbRi+yurmbu6lNP8lGDnXJJ5MmnFFqzdxv7KGKeN8C4u51xyeTJpxWatLKZdpjhpqE+h4pxLLk8mrdjsFSWcMLg7HXMiOc/COdeGeDJppYp3lfPepp0+hYpzLiU8mbRSrxcGpwT74LtzLhU8mbRSs1eU0KNjNqP6JbJSgHPONY4nk1bIzJgVTqGS4VOoOOdSwJNJK/T+5l2U7C73KVSccynjyaQVmr0yGC/xwXfnXKp4MmmFZq8sYUSfTvTtmht1KM65NsKTSSuzr6KauavL/KjEOZdSnkxambfWlFFRFfPxEudcSnkyaWVmrygmOyvDp1BxzqWUJ5NWZvbKEsYM6UH77MyoQ3HOtSFRrbR4iaSlkmKS8uPKz5K0QNK74f1n4upOCMsLJd0tSWF5D0kvSVoZ3neP4j2lgy0797N8yy7v4nLOpVxURyZLgAuBWbXKS4DzzOxogvXcH42ruxf4
CsG68HnAhLD8ZuAVM8sDXgkft0mzV5YAfkqwcy71IkkmZrbMzJbXUf62mX0QPlwKtJeUI+lQoIuZzTEzAx4Bzg/bTQKmh9vT48rbnFkriunVKYcj+naOOhTnXBuTzmMmFwELzawc6A9siKvbEJYB9DGzTeH2ZqBPfS8oaaqk+ZLmFxcXJyPmyMRixuuFJZya51OoOOdSL2kLXUh6GehbR9WtZvbsAZ47CvgFcPbB7NPMTJI1UD8NmAaQn59fb7uW6L1NOynbU+HjJc65SCQtmZjZ+KY8T9IA4GngKjNbFRZvBAbENRsQlgFskXSomW0Ku8O2NjXmlmxWOIXKKZ5MnHMRSKtuLkndgOeAm83sjZrysBtrp6Sx4VlcVwE1RzczCAbrCe8bPOpprWavKOHIQ7twSGefQsU5l3pRnRp8gaQNwDjgOUkzw6obgMOAH0laFN4OCeu+DjwAFAKrgBfC8juAsyStBMaHj9uUvRVVzF9bxml+VOKci0gki4Ob2dMEXVm1y38O/Lye58wHjqqjvBQ4s7ljbEnmFpVRWW1+SrBzLjJp1c3lmuY/K4rJycogf0ibvV7TORcxTyatwOyVxZw0rCe57XwKFedcNDyZtHAbt+9jVfEeHy9xzkXKk0kL93p4SvBpI3y8xDkXHU8mLdyslSX06ZJD3iGdog7FOdeGeTJpwapjxhuFJZya15twEmXnnIuEJ5MW7N2NO9i+t9KnUHHORc6TSQs2e0U4hcphnkycc9HyZNKCzV5ZwlH9u9CzU07UoTjn2jhPJi3Urv2VLFy3jdP8qnfnXBrwZNJCzSkqoyrmU6g459KDJ5MWavbKYjpkZzJ6cLeoQ3HOOU8mLdXslSWMHdaTnCyfQsU5Fz1PJi3Q+rK9rC7Z46cEO+fShieTFqhmVUUfL3HOpYuoFse6RNJSSTFJ+XXUD5K0W9L34somSFouqVDSzXHlQyXNDcufkJSdqvcRldkrSujXNZfhvTtGHYpzzgHRHZksAS4EZtVTfxcfraSIpEzgD8BEYCRwmaSRYfUvgN+Y2WHANuDaZAWdDqqqY7yxqoTTRvgUKs659BFJMjGzZWa2vK46SecDq4GlccVjgEIzKzKzCuBxYFK4HvxngH+E7aYD5ycv8ui9s2EHu/ZXeReXcy6tpNWYiaROwE3AT2pV9QfWxz3eEJb1BLabWVWt8vpef6qk+ZLmFxcXN1/gKVJZHeOZtzciwacO6xl1OM4596GkrQEv6WWgbx1Vt5rZs/U87TaCLqvdyejCMbNpwDSA/Px8a/YdJEnJ7nIef2sdf5mzjs0793PWyD5069Dqh4accy1I0pKJmY1vwtNOAi6WdCfQDYhJ2g8sAAbGtRsAbARKgW6SssKjk5ryVuHdDTv485ur+b93NlFRHePUvF78vwuP4owRh0QdmnPOfUzSkklTmNmpNduSbgN2m9nvJWUBeZKGEiSLycDlZmaS/g1cTDCOMgWo76inRaisjvHCks08/MZqFq7bTsfsTCaPGchV44ZwmC+A5ZxLU5EkE0kXAPcAvYHnJC0ys3Pqa29mVZJuAGYCmcBDZlYzQH8T8LiknwNvAw8mN/rkKN5VzmNz1/HXuWvZuqucIT078OPzRnLRCQPoktsu6vCcc65BMmsxQwfNKj8/3+bPnx91GLyzfjsPv7mG5xYHXVmnj+jN1Z8awul5vcnI8FN/nXPpRdICM/vE9YFp1c3VVlRUxXj+3U08/OYaFq3fTqecLC4/aRBXjRvMsN7eleWca3k8maTYUws2cMe/3qd4VznDenXkJ58fxYWj+9PZu7Kccy2YJ5MUemXZFm78xzscP6g7v7rkWE49rJd3ZTnnWgVPJiny3gc7+cbf3mZUv648eu0YOmT7R++caz3S6gr41mrrzv1cO30eXdu344Ep+Z5InHOtjn+rJdneiiqunT6fHfsq+ft14+jTJTfqkJxzrtn5kUkSxWLGd55YxNIPdnDPZcczql/XqENyzrmk8GSS
RL+Y+T4zl27hB58dyZlH9ok6HOecSxpPJkny+Fvr+NN/irhi7CCu+dSQqMNxzrmk8mSSBG8WlvCDZ5Zwal4vbjtvlC9i5Zxr9TyZNLPCrbu57i8LGNa7I3/44miyMv0jds61fv5N14zK9lTwpYfnkZ2VwYNTTvQJGp1zbYafGtxMyquq+eqj89m8cz+PTx3LwB4dog7JOedSxo9MmoGZcfNT7zJvzTZ+fcmxjB7UPeqQnHMupTyZNIN7Xi3k6bc38t2zRnDesf2iDsc551LOk0mCnl20kbteWsGFx/fnhs8cFnU4zjkXiUiSiaRLJC2VFJOUX6vuGEkFYf27knLD8hPCx4WS7lZ4vq2kHpJekrQyvE9ZH9OCtWXc+I/FjBnSg/+56Gg/Bdg512ZFdWSyBLgQmBVfGK71/hfgOjMbBZwBVIbV9wJfAfLC24Sw/GbgFTPLA14JHyfd+rK9TH1kAf265vKnK08gJyszFbt1zrm0FEkyMbNlZra8jqqzgcVm9k7YrtTMqiUdCnQxszkWrDP8CHB++JxJwPRwe3pcedLs2FfJNQ/Po7I6xoNXn0j3jtnJ3qVzzqW1dBszGQGYpJmSFkr6fljeH9gQ125DWAbQx8w2hdubgXonwZI0VdJ8SfOLi4ubFGBldYwbHlvImpI93HflCQz3ZXadcy5515lIehnoW0fVrWb2bAPxnAKcCOwFXpG0ANjRmH2amUmyBuqnAdMA8vPz623XwPP58YylzF5Zwp0XH8PJw3sd7Es451yrlLRkYmbjm/C0DcAsMysBkPQ8MJpgHGVAXLsBwMZwe4ukQ81sU9gdtjWBsA9oeO9OXP/p4Xwhf2Ayd+Occy1KunVzzQSOltQhHIw/HXgv7MbaKWlseBbXVUDN0c0MYEq4PSWuvNlJ4tpThnLjOUckaxfOOdciRXVq8AWSNgDjgOckzQQws23AXcA8YBGw0MyeC5/2deABoBBYBbwQlt8BnCVpJTA+fOyccy6FFJwc1fbk5+fb/Pnzow7DOedaFEkLzCy/dnm6dXM555xrgTyZOOecS5gnE+eccwnzZOKccy5hnkycc84lzJOJc865hLXZU4MlFQNrm/j0XkBJM4bT3Dy+xHh8ifH4EpPu8Q02s961C9tsMkmEpPl1nWedLjy+xHh8ifH4EpPu8dXHu7mcc84lzJOJc865hHkyaZppUQdwAB5fYjy+xHh8iUn3+OrkYybOOecS5kcmzjnnEubJxDnnXMI8mTRA0gRJyyUVSrq5jvocSU+E9XMlDUlhbAMl/VvSe5KWSvpWHW3OkLRD0qLw9qNUxRfuf42kd8N9f2K+fwXuDj+/xZJGpzC2w+M+l0WSdkr6dq02Kf38JD0kaaukJXFlPSS9JGlleN+9nudOCduslDSlrjZJiu+Xkt4Pf35PS+pWz3Mb/F1IYny3SdoY9zM8t57nNvi3nsT4noiLbY2kRfU8N+mfX8LMzG913IBMgkW4hgHZwDvAyFptvg7cF25PBp5IYXyHAqPD7c7AijriOwP4vwg/wzVArwbqzyVY5EzAWGBuhD/rzQQXY0X2+QGnESxTvSSu7E7g5nD7ZuAXdTyvB1AU3ncPt7unKL6zgaxw+xd1xdeY34Ukxncb8L1G/Pwb/FtPVny16n8N/Ciqzy/Rmx+Z1G8MUGhmRWZWATwOTKrVZhIwPdz+B3BmuKxw0pnZJjNbGG7vApYB/VOx72Y0CXjEAnOAbpIOjSCOM4FVZtbUGRGahZnNAspqFcf/jk0Hzq/jqecAL5lZmQWrlb4ETEhFfGb2oplVhQ/nAAOae7+NVc/n1xiN+VtPWEPxhd8bXwD+1tz7TRVPJvXrD6yPe7yBT35Zf9gm/IPaAfRMSXRxwu6144G5dVSPk/SOpBckjUppYGDAi5IWSJpaR31jPuNUmEz9f8RRfn4AfcxsU7i9GehTR5t0+Ry/xEfLadd2oN+FZLoh7IZ7qJ5uwnT4/E4FtpjZynrqo/z8GsWTSQsn
qRPwFPBtM9tZq3ohQdfNscA9wDMpDu8UMxsNTASul3Raivd/QJKygc8Df6+jOurP72Ms6O9Iy3P5Jd0KVAF/radJVL8L9wLDgeOATQRdSenoMho+Kkn7vyVPJvXbCAyMezwgLKuzjaQsoCtQmpLogn22I0gkfzWzf9auN7OdZrY73H4eaCepV6riM7ON4f1W4GmC7oR4jfmMk20isNDMttSuiPrzC22p6foL77fW0SbSz1HS1cDngC+GCe8TGvG7kBRmtsXMqs0sBtxfz36j/vyygAuBJ+prE9XndzA8mdRvHpAnaWj43+tkYEatNjOAmjNnLgZere+PqbmFfawPAsvM7K562vStGcORNIbg552SZCepo6TONdsEA7VLajWbAVwVntU1FtgR16WTKvX+Rxjl5xcn/ndsCvBsHW1mAmdL6h5245wdliWdpAnA94HPm9neeto05nchWfHFj8FdUM9+G/O3nkzjgffNbENdlVF+fgcl6jMA0vlGcLbRCoIzPW4Ny35K8IcDkEvQPVIIvAUMS2FspxB0eSwGFoW3c4HrgOvCNjcASwnOTpkDnJzC+IaF+30njKHm84uPT8Afws/3XSA/xT/fjgTJoWtcWWSfH0FS2wRUEvTbX0swBvcKsBJ4GegRts0HHoh77pfC38NC4JoUxldIMN5Q8ztYc3ZjP+D5hn4XUhTfo+Hv1mKCBHFo7fjCx5/4W09FfGH5wzW/c3FtU/75JXrz6VScc84lzLu5nHPOJcyTiXPOuYR5MnHOOZcwTybOOecS5snEOedcwjyZOJcASdW1Zh9ucMZZSddJuqoZ9rsmggsonauXnxrsXAIk7TazThHsdw3BdTklqd63c3XxIxPnkiA8crgzXIPiLUmHheW3SfpeuP1NBevRLJb0eFjWQ9IzYdkcSceE5T0lvahg7ZoHCC74rNnXFeE+Fkn6k6TM8PawpCVhDN+J4GNwbYgnE+cS075WN9elcXU7zOxo4PfAb+t47s3A8WZ2DMGV9wA/Ad4Oy/4beCQs/zHwupmNIpibaRCApCOBS4FPmdlxQDXwRYKJDfub2VFhDH9uxvfs3CdkRR2Acy3cvvBLvC5/i7v/TR31i4G/SnqGj2YkPgW4CMDMXg2PSLoQLKx0YVj+nKRtYfszgROAeeE0Yu0JJoP8X2CYpHuA54AXm/4WnTswPzJxLnmsnu0anyWYm2w0QTJoyj93Aqab2XHh7XAzu82CRbKOBV4jOOp5oAmv7VyjeTJxLnkujbsviK+QlAEMNLN/AzcRLF/QCZhN0E2FpDOAEgvWqZkFXB6WTyRYnheCSSAvlnRIWNdD0uDwTK8MM3sK+AFBwnIuabyby7nEtJe0KO7xv8ys5vTg7pIWA+UEU93HywT+IqkrwdHF3Wa2XdJtwEPh8/by0fTzPwH+Jmkp8CawDsDM3pP0A4JV+DIIZqS9HtgH/DksA7il+d6yc5/kpwY7lwR+6q5ra7ybyznnXML8yMQ551zC/MjEOedcwjyZOOecS5gnE+eccwnzZOKccy5hnkycc84l7P8DJ679CZQ+algAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Seed every RNG used by the run so Restart & Run All reproduces the curve.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "real_ratio = 0.5\n",
    "env_name = 'Pendulum-v0'\n",
    "env = gym.make(env_name)\n",
    "env.seed(SEED)  # legacy gym API (available alongside Pendulum-v0)\n",
    "num_episodes = 20\n",
    "actor_lr = 5e-4\n",
    "critic_lr = 5e-3\n",
    "alpha_lr = 1e-3\n",
    "hidden_dim = 128\n",
    "gamma = 0.98\n",
    "tau = 0.005  # soft-update coefficient\n",
    "buffer_size = 10000\n",
    "model_alpha = 0.01  # weighting coefficient in the model loss function\n",
    "state_dim = env.observation_space.shape[0]\n",
    "action_dim = env.action_space.shape[0]\n",
    "action_bound = env.action_space.high[0]  # maximum action value\n",
    "# Standard SAC heuristic: target entropy = -(action dimension).\n",
    "# For Pendulum (1-D action) this is -1, the value previously hard-coded.\n",
    "target_entropy = -action_dim\n",
    "\n",
    "rollout_batch_size = 1000\n",
    "rollout_length = 1  # rollout length k; experimenting with other values is recommended\n",
    "model_pool_size = rollout_batch_size * rollout_length\n",
    "\n",
    "agent = SAC(state_dim, hidden_dim, action_dim, action_bound, actor_lr,\n",
    "            critic_lr, alpha_lr, target_entropy, tau, gamma)\n",
    "model = EnsembleDynamicsModel(state_dim, action_dim, model_alpha)\n",
    "fake_env = FakeEnv(model)\n",
    "env_pool = ReplayBuffer(buffer_size)\n",
    "model_pool = ReplayBuffer(model_pool_size)\n",
    "mbpo = MBPO(env, agent, fake_env, env_pool, model_pool, rollout_length,\n",
    "            rollout_batch_size, real_ratio, num_episodes)\n",
    "\n",
    "return_list = mbpo.train()\n",
    "\n",
    "episodes_list = list(range(len(return_list)))\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(episodes_list, return_list)\n",
    "ax.set_xlabel('Episodes')\n",
    "ax.set_ylabel('Returns')\n",
    "ax.set_title('MBPO on {}'.format(env_name))\n",
    "plt.show()\n",
    "\n",
    "# Sample per-episode output from an earlier run made before seeding was\n",
    "# added; exact numbers will differ from a fresh run.\n",
    "# episode: 1, return: -1083\n",
    "# episode: 2, return: -1324\n",
    "# episode: 3, return: -979\n",
    "# episode: 4, return: -130\n",
    "# episode: 5, return: -246\n",
    "# episode: 6, return: -2\n",
    "# episode: 7, return: -239\n",
    "# episode: 8, return: -2\n",
    "# episode: 9, return: -122\n",
    "# episode: 10, return: -236\n",
    "# episode: 11, return: -238\n",
    "# episode: 12, return: -2\n",
    "# episode: 13, return: -127\n",
    "# episode: 14, return: -128\n",
    "# episode: 15, return: -125\n",
    "# episode: 16, return: -124\n",
    "# episode: 17, return: -125\n",
    "# episode: 18, return: -247\n",
    "# episode: 19, return: -127\n",
    "# episode: 20, return: -129"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第17章-基于模型的策略优化.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
